### https://github.com/dfldylan/FluidMLP.git^fdb2d7d6f512b79e079d9e483e2cb411f66389f5

import os
import json

import tensorflow as tf
from dataset import DataSet
from model import Model

# Step counter; -1 means "fresh run", overwritten when a checkpoint is restored.
global_step = -1
config_json = r'./config.json'

if __name__ == '__main__':
    # Let TensorFlow allocate GPU memory on demand rather than grabbing it all.
    gpus = tf.config.list_physical_devices(device_type='GPU')
    # tf.config.set_visible_devices(devices=gpus[1:2], device_type='GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(device=gpu, enable=True)

    # Load the run configuration; use a context manager so the file handle
    # is closed promptly (the original `json.load(open(...))` leaked it).
    with open(config_json) as f:
        cfg = json.load(f)

    # load the dataset
    dataset = DataSet(cfg['datasets_folder'], dt=cfg['dt'])

    # Make sure the output folders exist before the model/writer touch them.
    os.makedirs(cfg['model_folder'], exist_ok=True)
    os.makedirs(cfg['tensorboard_folder'], exist_ok=True)

    # Build the trainable model, optimizer, and checkpoint machinery.
    model = Model(cfg['mlp_units'], cfg['query_ball_radius'], cfg['num_max_neighbor'], trainable=True)
    optimizer = tf.optimizers.Adam(learning_rate=cfg['learning_rate'])
    checkpoint = tf.train.Checkpoint(model=model)
    manager = tf.train.CheckpointManager(checkpoint, directory=cfg['model_folder'], max_to_keep=10,
                                         keep_checkpoint_every_n_hours=1)

    # Resume from the latest checkpoint if one exists (single lookup instead of
    # two). The step counter is recovered from the checkpoint filename, which
    # has the form "<prefix>-<step>".
    latest_checkpoint = tf.train.latest_checkpoint(cfg['model_folder'])
    if latest_checkpoint is not None:
        checkpoint.restore(latest_checkpoint)
        global_step = int(os.path.split(latest_checkpoint)[1].split('-')[1])
    summary_writer = tf.summary.create_file_writer(cfg['tensorboard_folder'])

    # Train indefinitely. Each batch row holds the sample features: columns
    # 0:7 are the network inputs and columns 7:10 the regression targets
    # (presumably per-particle state and target velocity/acceleration —
    # confirm against DataSet).
    while True:
        print(f'step {global_step + 1}', end='...')
        current_data = dataset.get_batch()
        _inputs = current_data[:, :7]
        _outputs = current_data[:, 7:10]
        current_loss = model.train(inputs=_inputs, outputs=_outputs, optimizer=optimizer)

        global_step += 1
        print(str(current_loss.numpy()), end='...')
        # Every 10 steps: persist a checkpoint and log the loss to TensorBoard.
        if global_step % 10 == 0:
            print('save model', end='...')
            manager.save(checkpoint_number=global_step)
            with summary_writer.as_default():
                tf.summary.scalar("loss", current_loss, step=global_step)

        print('ok!')
